/*
    Copyright 2006-2011 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// Extra mcrx functions: functions to read and write lost energy and 
/// random number generator states, calculate integrated SED and
/// dust-attenuated images. \ingroup mcrx

#include "mcrx.h"
#include "blitz-fits.h"
#include "biniostream.h"
#include "fits-utilities.h"
#include "interpolatort.h"
#include "constants.h"
#include "misc.h"
#include "mpi_util.h"
#include <boost/archive/binary_oarchive.hpp>
#include <boost/archive/binary_iarchive.hpp>
#include <cmath>
#include <sstream>

using namespace blitz;
using namespace CCfits;

/** Loads the lost energy variable from the keywords in a FITS HDU.

    Reads the keywords "L_lost<i>" and "normalization<i>" and returns
    their ratio. Missing keywords are not an error; they leave the
    corresponding value at zero. If the normalization is zero (for
    example because the keyword is missing) the ratio is undefined, so 0
    is returned instead of the NaN/Inf a plain division would produce.
    A warning is printed if a nonzero lost energy is being discarded
    that way.

    \param hdu  HDU containing the keywords.
    \param i    Index appended to the keyword names.
    \return     L_lost<i> / normalization<i>, or 0 if the normalization
                is zero/unavailable. */
mcrx::T_float mcrx::Mcrx::load_lost_energy (CCfits::ExtHDU& hdu, int i) const
{
  ostringstream ost;
  ost << i;  
  T_float lost_energy= 0 ;
  T_float normalization = 0;
  try {
    hdu.readKey("L_lost" +ost.str (), lost_energy);
    hdu.readKey("normalization" +ost.str (), normalization);
  }
  catch (...) {
    // The keywords are optional; missing ones keep their zero defaults.
  }
  assert (lost_energy >= 0);

  if (normalization == 0) {
    // Dividing would give NaN (0/0) or Inf, which would poison any
    // energy bookkeeping that uses this value.
    if (lost_energy != 0)
      cerr << "Warning: load_lost_energy found L_lost" << ost.str ()
	   << " but no valid normalization" << ost.str ()
	   << "; ignoring lost energy" << endl;
    return 0;
  }
  return lost_energy/normalization;
}


/** Writes the lost energy variable to the specified FITS HDU as a
    keyword.

    The value stored in keyword "L_lost<i>" is lost_energy*normalization,
    labelled with the L_lambda unit. Before writing, the existing
    "normalization<i>" keyword, if any, is read back. A warning is
    printed when it differs from the supplied normalization by more than
    a relative 1e-9, in either direction. A missing keyword reads as 0
    and therefore also triggers the warning for a nonzero normalization.
    The normalization keyword itself is not written here.

    \param hdu            HDU to write the keyword to.
    \param lost_energy    Lost energy (must be >= 0).
    \param normalization  Factor the lost energy is multiplied by before
                          being stored.
    \param i              Index appended to the keyword names. */
void mcrx::Mcrx::write_lost_energy (CCfits::ExtHDU& hdu, T_float lost_energy,
				    T_float normalization, int i) const
{
  assert (lost_energy >= 0);
  
  ostringstream ost;
  ost << i;
  // ensure that previous normalization is consistent
  T_float old_normalization = 0;
  try {
    hdu.readKey("normalization" +ost.str (), old_normalization);
  }
  catch (...) {}
  // Symmetric relative comparison, written without a division so that a
  // zero normalization cannot produce NaN/Inf (a zero normalization only
  // warns if the old value differs from it at all).
  if (std::fabs(normalization - old_normalization) >
      1e-9*std::fabs(normalization)) {
    cerr << "Warning: Inconsistent normalization in write_lost_energy: " 
	 << normalization << " vs. " << old_normalization << endl;
  }
  // units is const in this const member function, but the lookup below
  // presumably needs the non-const operator[] of T_unit_map.
  // NOTE(review): if "L_lambda" is absent, a map operator[] would insert
  // into the (logically const) units map.
  T_unit_map& u =const_cast<T_unit_map&> (units);
  hdu.addKey("L_lost" +ost.str (), lost_energy*normalization, "[" +
	     u ["L_lambda"] +
	     "] Energy lost in Monte Carlo wavelength "+ost.str ());
}


/** Loads the random-number generator states from the dump file.

    Expected format (as written by write_rng_states(binofstream&)): an
    opening '"', one state string per line, and a closing '"' directly
    after the last newline. The previously held states are discarded.

    Throws bad_rng_state if the opening quote is missing. That is a
    function-local class, so callers can only catch it with catch(...).
    Throws the int 2 if the stream ends or fails before the closing
    quote is found. */
void mcrx::Mcrx::load_rng_states (binifstream& file)
{
  class bad_rng_state {};
  if (file.get() != '"')
    throw bad_rng_state();

  rng_states.clear();
  while (true) {
    string s;
    getline (file, s);

    // Check the stream before looking at s. A truncated dump must throw
    // rather than trip the assertion below. A stream in a failed state
    // (without eofbit) must also throw: peek() on it never returns '"',
    // so the loop would otherwise never terminate.
    if (file.eof() || file.fail())
      throw 2 ;
    assert(s.length()>0);
    
    rng_states.push_back(T_rng::T_state (s)) ;
    if (file.peek() == '"') {
      // we are done, remove ending quote from stream
      char c;
      file.get(c); 
      break ;
    }
  }

  assert (!rng_states.empty());
}


/** Writes the random number generator states to the dump file.

    Output layout: a '"', then the str() representation of every state
    followed by a newline, then a closing '"'. This is the layout
    load_rng_states(binifstream&) parses. */
void mcrx::Mcrx::write_rng_states (binofstream& file)
{
  assert(!rng_states.empty());

  const char quote = '"';
  file << quote;
  for (vector<T_rng::T_state>::size_type n = 0; n < rng_states.size(); ++n) {
    const string serialized = rng_states[n].str();
    assert(serialized.length() > 1);
    file << serialized << '\n';
  }
  file << quote;

  assert(file.good());
}


/** Loads the random-number generator states for this task from a FITS file.

    The RANDOM_STATE HDU stores one boost-binary-serialized state per row
    in the fixed-width byte column "states". The rows are split evenly
    over the MPI tasks. With n_states = rows/mpi_size(), task r reads the
    (1-based) rows r*n_states+1 .. (r+1)*n_states. The previously held
    states are discarded.

    Throws 1 (after printing the cfitsio error) if a cfitsio call fails. */
void mcrx::Mcrx::load_rng_states (CCfits::FITS& file)
{  
  // try to open HDU
  const string HDU_name = "RANDOM_STATE";
  CCfits::ExtHDU& hdu = open_HDU (file, HDU_name) ;
  // the raw cfitsio calls below operate on the current HDU
  hdu.makeThisCurrent();
  const int colnum = hdu.column("states").index();
  const long n_rows = hdu.column("states").rows();

  // for some reason, CCfits always returns width 1 for Tbyte columns,
  // so ask cfitsio for the repeat count directly
  long statewidth = 0;
  int naxis = 0, status=0;
  fits_read_tdim(file.fitsPointer(), colnum, 1, &naxis, &statewidth, &status);
  if (status) {
    // statewidth is not valid, don't use it to size a buffer
    fits_report_error (stderr, status);
    throw 1;
  }
  assert(statewidth > 0);

  // every task must get the same, nonzero number of states
  assert(n_rows % mpi_size() == 0);
  const int n_states = n_rows/mpi_size();
  assert(n_states > 0);

  cout << "Loading " << n_states << " RNG states of " << statewidth << " bytes"<<endl;

  rng_states.clear(); 
  
  const int start_row = mpi_rank()*n_states+1;

  vector<char> buf(statewidth);

  for (int i=0; i<n_states; ++i) {
    fits_read_col(file.fitsPointer(), TBYTE, colnum, start_row+i,
		  1, statewidth, 0, &buf[0], 0, &status);
    // don't try to deserialize a buffer the read did not fill
    if (status) {
      fits_report_error (stderr, status);
      throw 1;
    }

    std::istringstream s(string(buf.begin(), buf.end()));
    boost::archive::binary_iarchive ar(s);
    T_rng::T_state state;
    ar >> state;
    rng_states.push_back(state);
  }  
  
  assert (!rng_states.empty());
}

/** Collects the random number states from all the tasks and writes them
    to a FITS file.

    mpi_collect is called on every task (presumably a collective
    operation, so every task must call this function). Only task 0 uses
    the result. It concatenates the per-task state vectors, in the order
    mpi_collect returns them, and writes them with write_rng_states. */
void mcrx::Mcrx::save_rng_states(CCfits::FITS& file)
{
  if (mpi_rank() != 0) {
    // take part in the collection; the result is only needed on task 0
    mpi_collect(rng_states);
    return;
  }

  typedef vector<T_rng::T_state> T_states;
  const vector<T_states> per_task = mpi_collect(rng_states);

  T_states all_states;
  for (vector<T_states>::size_type t = 0; t < per_task.size(); ++t)
    all_states.insert(all_states.end(),
		      per_task[t].begin(), per_task[t].end());

  write_rng_states(file, all_states);
}

/** Writes a vector of random-number generator states to a FITS file.

    Each state is serialized with a boost binary archive and written as
    one row of the fixed-width byte column "states" in the RANDOM_STATE
    HDU, starting at row 1. The HDU and column are created if the HDU
    does not exist.

    Throws 1 if the states do not all serialize to the same length, if
    an existing column is too narrow to hold them, or (after printing
    the cfitsio error) if a cfitsio call fails. */
void mcrx::Mcrx::write_rng_states (CCfits::FITS& file, 
				   const vector<T_rng::T_state>& states)
{
  assert(!states.empty());

  // serialize the states. we use a binary archive, because then all
  // states have the same serialized length so we can just put it in a
  // fixed-length binary column
  vector<string> state_strs;
  state_strs.reserve(states.size());
  for (vector<T_rng::T_state>::const_iterator i = states.begin() ;
       i != states.end() ; ++i) {
    stringstream s;
    boost::archive::binary_oarchive ar(s);
    ar << *i;
    state_strs.push_back(s.str());
  }

  const int statewidth = state_strs.front().size();
  // the fixed-width column requires equal lengths; writing statewidth
  // bytes from a shorter string would read past its end
  for (vector<string>::size_type k = 0; k < state_strs.size(); ++k) {
    if (static_cast<int>(state_strs[k].size()) != statewidth) {
      cerr << "Error: RNG states serialize to different lengths ("
	   << state_strs[k].size() << " vs. " << statewidth
	   << " bytes)" << endl;
      throw 1;
    }
  }

  cout << "Writing " << state_strs.size() << " RNG states" << endl;

  // try to open HDU
  const string HDU_name = "RANDOM_STATE";
  CCfits::ExtHDU* hdu;
  try {
    hdu = &open_HDU (file, HDU_name) ;
  }
  catch (CCfits::FITS::NoSuchHDU&) {
    // doesn't exist, we have to create it
    hdu = file.addTable(HDU_name, 0);
    hdu->addColumn(Tbyte, "states", statewidth);
  }
  // the raw cfitsio calls below operate on the current HDU
  hdu->makeThisCurrent();
  const int colnum = hdu->column("states").index();

  int status=0;
  // we better make sure there's enough space in the column: a write
  // longer than the repeat count would spill into the following rows.
  // CCfits reports width 1 for Tbyte columns, so ask cfitsio instead.
  long colwidth = 0;
  int naxis = 0;
  fits_read_tdim(file.fitsPointer(), colnum, 1, &naxis, &colwidth, &status);
  if (!status && colwidth < statewidth) {
    cerr << "Error: RANDOM_STATE column width " << colwidth
	 << " is too small for RNG states of " << statewidth
	 << " bytes" << endl;
    throw 1;
  }

  const int start_row = 1;
  for (vector<string>::size_type i=0; i<state_strs.size(); ++i) 
    // it's such a pain to convince CCfits to do this so we use cfitsio...
    // (cfitsio calls do nothing once status is nonzero)
    fits_write_col(file.fitsPointer(), TBYTE, colnum, start_row+i,
		   1, statewidth, 
		   const_cast<char*>(state_strs[i].data()), &status);

  if (status) {
    fits_report_error (stderr, status);
    throw 1;
  }
}  


/** For monochromatic runs, calculates the deposition fraction at every
    wavelength of the full SED.

    For each scattering wavelength k that was run, the fraction is
    DEP_TOT<k> (a keyword of the DEPOSITION HDU) divided by the source
    luminosity L_lambda (INTEGRATED_QUANTITIES HDU) at wavelength index
    entry[k]. On the full wavelength grid the result is:
    - 0 outside [entry.front(), entry.back()];
    - the computed value at the run wavelengths, including lines;
    - elsewhere, 10^interpolate(log10(wavelength)), using an interpolator
      built from log10 of the fractions at the non-line wavelengths.

    Assumes entry is sorted ascending and continuum_lambdas holds the
    log10 of the non-line run wavelengths (the interpolator is queried
    with log10(wavelength)) — TODO confirm against callers.
    This value is then used to interpolate the scattering wavelengths to
    the full wavelength SED. */
void
mcrx::Mcrx::calc_deposition_fraction(CCfits::FITS& output,
				     const vector<T_float>& continuum_lambdas,
				     const array_1& wavelengths,
				     const vector<long>& entry,
				     const vector<bool>& line,
				     array_1& deposition_fraction) const
{
  // source luminosity at every wavelength
  ExtHDU& iq_HDU = open_HDU (output, "INTEGRATED_QUANTITIES");
  array_1 L_lambda_total;
  read(iq_HDU.column("L_lambda"), L_lambda_total);

  // interpolator over the continuum (non-line) run wavelengths
  typedef interpolator< T_float , T_float , 1> T_interpolator;
  T_interpolator continuum_interp;
  continuum_interp.setAxis(0, continuum_lambdas) ;

  ExtHDU& deposition_HDU = open_HDU (output, "DEPOSITION");

  // deposition fraction at each wavelength that was actually run
  vector<T_float> run_fraction;
  int n_continuum = 0;
  for (int k = 0; k < entry.size(); ++k) {
    ostringstream keyname;
    keyname << "DEP_TOT" << k;
    T_float deposited = 0;
    deposition_HDU.readKey(keyname.str (), deposited);
    const T_float fraction = deposited/L_lambda_total (entry[k]);
    run_fraction.push_back(fraction);

    // lines are kept out of the continuum interpolator
    if (line [k])
      continue;
    assert(n_continuum < continuum_lambdas.size());
    continuum_interp.setPoint(n_continuum, log10 (fraction));
    ++n_continuum;
  }

  // spread onto the full wavelength grid; 'next' is the index of the
  // first run wavelength not yet reached
  int next = 0;
  for (int l = 0 ; l < deposition_fraction.size(); ++l) {
    const bool outside = (l < entry [0]) || (l > entry.back());
    if (outside) {
      deposition_fraction (l) = 0 ;
    }
    else if (l == entry [next]) {
      // a wavelength we ran (possibly a line): use the computed value
      deposition_fraction (l) = run_fraction [next];
      ++next;
    }
    else {
      // between run wavelengths: interpolate in log space
      deposition_fraction (l) = 
	pow (10, continuum_interp.interpolate(log10 (wavelengths (l)))) ;
    }
  }
}



/** For monochromatic runs, calculates the full-SED attenuated images
    by interpolating the attenuation between the wavelength
    points. This function is also surprisingly long... */
bool mcrx::Mcrx::calculate_attenuated_images ()
{
  if (terminator ()) return true;
  
  if(!is_mpi_master())
    return false;

  cout << "\n  *** POSTPROCESSING PART 2 ***\n" << endl;

  typedef interpolator<array_2, T_float, 1> T_interpolator;

  const string output_file_name = 
    word_expand (p.getValue("output_file", string ())) [0];
  auto_ptr<FITS> output;
  output.reset(new FITS (output_file_name, Read));

  // this function is thoroughly unnecessary if we are running
  // polychromatic. Hence, check that first.
  bool poly = false;
  try {
    open_HDU(*output, "SCATTERING_LAMBDAS").readKey("POLY", poly);
  }
  catch (...) {}
  if (poly) {
    output->pHDU().addKey("MCRX_PP2", true,
			  "Postproc. stage 2 unnecessary" );
    cout << "  Polychromatic run, stage unnecessary." << endl;
    return terminator ();
  }
  

  // check if previously complete
  bool reentry = false;
  try {
    output->pHDU ().readKey("MCRX_PP2", reentry);
  }
  catch (...) {}
  if (reentry) {
    cout << "  Stage already complete" << endl;
    return terminator ();
  }

  output.reset(0 );
  output.reset(new FITS (output_file_name, Write));
  ExtHDU& lambda_hdu = open_HDU (*output, "LAMBDA");
  ExtHDU& scattering_lambdas_hdu = open_HDU (*output, "SCATTERING_LAMBDAS");
  Column& cl = lambda_hdu.column("lambda" );
  //Column& cLl = lambda_hdu.column("L_lambda_total" );
  array_1 wavelengths;
  read (cl, wavelengths);
  Column& ce = scattering_lambdas_hdu.column(sllec );
  vector<long> entry;
  vector<bool> line;
  ce.read(entry, 2, ce.rows() );
  scattering_lambdas_hdu.column("line" ).read(line, 2, ce.rows ());

  ExtHDU& makegrid_HDU = open_HDU (*output, "MAKEGRID");
  T_float L_bol_tot;
  makegrid_HDU.readKey("L_bol_tot", L_bol_tot);

  vector<T_float> scattering_lambdas, continuum_lambdas;
  for (int i = 0 ; i < entry.size() ;++i) {
    scattering_lambdas.push_back( log10 (wavelengths (entry [i])));
    if (!line [i])
      continuum_lambdas.push_back(scattering_lambdas.back());
  }
  
  int n_cameras;
  open_HDU (*output, "MCRX").readKey("N_CAMERA", n_cameras);

  int i = 0;
  // See if we should resume at a specific camera
  try {
    output->pHDU().readKey("MCRX_PP2_RESUME", i );
    output->pHDU().deleteKey ("MCRX_PP2_RESUME" );    
    cout << "Resuming postprocessing with camera " << i << endl;
  }
  catch (CCfits::HDU::NoSuchKeyword& h) {}   

  for (; i < n_cameras; ++i) {
    const int j=i;
    std::ostringstream ost;
    ost << j;
    const string camera_numstr=ost.str();

    // By closing and reopening the output file we don't keep the
    // large HDU's all in memory
    output.reset(new FITS (output_file_name, Write));

    // in case we get interrupted, write keyword so we know where we are
    output->pHDU().addKey("MCRX_PP2_RESUME", j,
			  "Postprocessing interrupted at this camera" );

    CCfits::ExtHDU*scatter = &open_HDU (*output, "CAMERA" +
					camera_numstr +
					"-SCATTER");
    CCfits::ExtHDU& nonscatter = open_HDU (*output, "CAMERA" +
					   camera_numstr+
					   "-NONSCATTER");
      
    array_3 scatter_image;
    Array< float, 3> nonscatter_image;
    cout << "Reading CAMERA" << j << "-SCATTER image" << endl;
    read (*scatter, scatter_image);
    // Check for terminator
    if (terminator ()) break;

    cout << "Reading CAMERA" << j << "-NONSCATTER image" << endl;
    read (nonscatter, nonscatter_image);

    if (any ( nonscatter_image < 0))
      cout << "Warning: negative numbers found in nonscatter_image!  Minimum value: " << min (nonscatter_image) << endl;

    // Check for terminator
    if (terminator ()) break;

    // calculate attenuation data cube
    array_3 attenuation_image (scatter_image.shape());

    cout << "Calculating attenuation data cube" << endl;
    for (int i = 0; i < entry.size(); ++i) {
      attenuation_image (Range::all (), Range::all (), i) =
	scatter_image (Range::all (), Range::all (), i)/
	nonscatter_image (Range::all (), Range::all (), entry [i]);
      if (terminator ()) break;
    }

    // build attenuation interpolator
    cout << "Building attenuation interpolator" << endl;
    T_interpolator ai,si;
    ai.setAxis(0, continuum_lambdas) ;
    si.setAxis(0, continuum_lambdas) ;
    ai.initializePoints(array_2 (attenuation_image.extent(0 ),
				 attenuation_image.extent(1 ))); 
    si.initializePoints(array_2 (attenuation_image.extent(0 ),
				 attenuation_image.extent(1 ))); 
    for (int i = 0, j = 0; i < entry.size() ; ++i) {
      // test for NaN resulting from 0/0 by testing for x != x
      // (0 flux in either pixel can safely be assumed to mean 0)
      // Inf comes from pixels with only scattered flux
      const array_2 temp
	(log10 (where ( attenuation_image
			(Range::all (), Range::all (), i)!= 
			attenuation_image
			(Range::all (), Range::all (), i),
			0, attenuation_image
			(Range::all (), Range::all (), i))+
		blitz::tiny (T_float ())));
      if (!line [i]) {
	ai.setPoint(j, temp);
	si.setPoint(j++, array_2 (log10 (scatter_image (Range::all (),
							Range::all (), i) +
					 blitz::tiny (T_float ()))));
      } 
      if (terminator ()) break;
    }

    if (terminator ()) break;
      
    // create CAMERAi HDU
    ExtHDU* camera_HDU;
    cout << " Creating CAMERA" << camera_numstr
	 << " image" << endl;
    try {
      camera_HDU = &open_HDU (*output, "CAMERA" + camera_numstr) ;
    }
    catch (CCfits::FITS:: NoSuchHDU&) {
      std::vector<long> naxes;
      naxes.push_back(attenuation_image.extent(firstDim ) );
      naxes.push_back(attenuation_image.extent(secondDim ) );
      naxes.push_back(entry.back() - entry.front() + 1 );
      // use the cfitsio tile compression on these images, put each
      // wavelength slice in one tile
      int status;
      std::vector<long> tile_size (naxes);
      tile_size [2] = 1 ;
      output->setCompressionType(GZIP_1);
      output->setTileDimensions(tile_size);
      camera_HDU = output->addImage ("CAMERA" + camera_numstr,
				     FLOAT_IMG, naxes);
      output->setCompressionType(0);
	
      camera_HDU ->writeComment("This HDU contains the full data cube of the object as a function of wavelength, including the effects of dust. The third dimension of the data cube is wavelength, see HDU LAMBDA for what the wavelengths are.");
      T_float sb_factor;
      nonscatter.readKey("sb_factr", sb_factor);
      Keyword& sbk = nonscatter.keyWord("sb_factr");
      string junk;
      nonscatter.readKey("imunit", junk);
      Keyword& imuk = nonscatter.keyWord("imunit");
      camera_HDU->makeThisCurrent();
      camera_HDU->addKey("sb_factr", sb_factor, sbk.comment());
      camera_HDU->addKey("imunit", junk, imuk.comment ());
    }

    // Now calculate full wavelength scatter image by using the
    // interpolator
    cout << "Interpolating image to full wavelength resolution" << endl;
    camera_HDU->makeThisCurrent();
    const int N = 3;
    TinyVector<int, N> axis;
    for (int i = 0; i < N; ++i)
      axis [i] = camera_HDU->axis(i );
    TinyVector<long, N> start = 1; 
    TinyVector<long, N> end = axis;
    int status = 0;
    FITSUtil::MatchType<T_float> imageType;

    for (int i = 0, j = 0; i < entry.back() - entry.front() + 1; ++i) {
      const array_2 attn
	(pow (10,ai.interpolate(log10 (wavelengths (i+ entry.front())))));
      const array_2 camera_slice
	// if attenuation >> 1, scattered light dominates and in that
	// case it doesn't make sense to base the output SED on the
	// nonscatter light.  In that case we simply interpolate the
	// scattered light.  (We don't make the cut off at exactly 1
	// because there will be noise in the value around 1 for the
	// wavelengths with little attenuation.)
	(where(attn <2,
	       attn*nonscatter_image(Range::all(),
				     Range::all (), i+ entry.front()),
	       pow (10,si.interpolate(log10 (wavelengths (i+ 
							  entry.front()))))));

      // Write this slice
      Array<T_float,2> temp (camera_slice.shape(), ColumnMajorArray<2> ());
      if (i+ entry.front() == entry [j])
	// it's one of the wavelengths we ran (which may be a line),
	// so just use the scatter image directly
	temp = scatter_image (Range::all (), Range::all (), j++);
      else
	temp = camera_slice;
      start [2] = i+ 1;
      end [2] = i+ 1;
      fits_write_subset (camera_HDU->fitsPointer(),imageType(),
			 &start [0], &end [0], temp.dataFirst (), &status);

      if (terminator ()) break;
    }
 
    if (status) {
      fits_report_error (stderr, status);
      throw 1;
    }

    // Check for terminator
    if (terminator ()) break;

  } // for camera

  if (terminator ()) {
    // we were terminated in the camera loop.  Keyword is already
    // written, so we don't need to do that.
    cout << "Postprocessing terminated" << endl;
  }
  else {
    // we are done, write keyword to inform of this fact
    output->pHDU().addKey("MCRX_PP2", true,
			  "Postprocessing stage 2 complete" );
    output->pHDU().deleteKey ("MCRX_PP2_RESUME" );    
    cout << "Postprocessing stage 2 complete" << endl;
  }
  
  return terminator ();
}

